# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# This file is part of ByteQC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from byteqc.cuobc.scf.hf import _VHFOpt
from time import time
import cupy
import numpy
from byteqc.cuobc import scf
from pyscf import gto
from pyscf import scf as cscf

# Benchmark script: build clusters of the first `nwater` water molecules from a
# fixed 32-molecule geometry, time ByteQC's GPU RHF (byteqc.cuobc.scf.RHF) in
# the cc-pVDZ basis, then run and time PySCF's CPU RHF on the same molecule as
# a reference and print both energies.

# Atom lines for 32 water molecules, three lines (O, H, H) per molecule.
# mol.unit is not set below, so PySCF's default length unit applies.
WATER_CLUSTER = '''
O 5.07073593 5.98522806 1.99074030
H 5.47266340 5.41048479 1.37726617
H 4.29230547 5.50148201 2.38133121
O 6.96930218 3.65419388 7.69751453
H 7.76904678 3.41141653 8.15803623
H 7.19632435 3.77704048 6.80754805
O 7.07114363 0.38541618 6.31931973
H 6.51683569 0.89463365 5.71413946
H 7.89046621 0.49394438 5.86833906
O 9.21707249 0.60753405 0.85845739
H 9.11190128 1.36154389 0.33284441
H 8.61856461 0.00654888 0.42679650
O 2.80207729 8.11080360 8.04087448
H 3.63785648 7.76187134 7.74992561
H 3.04515648 8.54183960 8.83652496
O 6.30238819 8.51380348 0.28597444
H 6.60682964 8.87906742 1.14078999
H 6.74352312 7.68866539 0.01975211
O 3.06945777 3.75312948 6.44377136
H 2.18450260 4.12506390 6.16728687
H 3.05599904 2.81985664 6.20426559
O 4.92952299 5.59797001 7.71687841
H 5.66505432 5.16413927 8.15059280
H 4.37021732 4.88033485 7.37394810
O 3.00636864 2.28544617 3.28489971
H 2.79496050 3.22795010 3.19227219
H 2.86957741 1.92318964 2.46056271
O 5.43206596 1.29600477 9.01845360
H 5.95717478 1.65866745 8.28544903
H 5.76271439 0.39199132 9.23204327
O 5.73688364 2.45922589 3.73790503
H 4.76279688 2.33512568 3.59089351
H 5.71681070 3.20236135 4.33737850
O 0.43886620 7.24681759 6.93701649
H 1.35015965 7.35265350 7.26723862
H 0.36180931 6.28662968 6.90553951
O 2.97921991 1.60385466 0.31199196
H 3.82191563 1.62987852 -0.06028282
H 2.50401211 2.41151643 0.03694036
O 6.17979431 9.74389553 2.92950201
H 5.39901018 9.41284370 3.40269041
H 6.28733921 10.60690498 3.31242228
O 7.81284142 6.60153103 8.99971962
H 8.57145596 6.90485859 8.46935654
H 8.14466572 6.17615747 9.85374928
O 2.49865532 5.32045507 2.67103815
H 2.25583696 5.87605810 1.85574365
H 2.32352066 5.87275410 3.53010368
O 7.31650305 6.90478277 3.30298328
H 7.31806421 7.83650684 3.40219688
H 6.50129271 6.76037025 2.74671173
O 8.81621552 3.30051851 4.64679909
H 9.30960655 2.96997213 3.91736698
H 8.21753120 4.02485323 4.31665134
O 0.19844723 0.21932974 6.09385014
H 1.09746647 0.43695515 5.93445396
H 0.17539406 -0.69561619 6.37414455
O 2.07528901 3.93826103 9.01441956
H 2.44687462 3.81776476 8.16495132
H 2.02195740 4.90535402 9.15301132
O 1.82536161 7.35551262 4.49379349
H 1.50180137 7.42050362 5.36343765
H 1.16098309 7.88878393 3.98919296
O 0.24929428 4.54504395 6.72563934
H -0.14175034 4.10964870 5.91322088
H 0.29467010 3.78147531 7.32770681
O 8.53362751 5.55051613 1.47352052
H 8.20640278 6.24974823 2.08278370
H 7.86576223 4.90864420 1.61348271
O 4.05105782 9.05945301 4.29349422
H 4.50658369 8.61413479 5.01351023
H 3.37354660 8.40744495 4.11785507
O 2.91260624 0.95893687 6.15000582
H 3.18501997 0.38297644 6.85079622
H 3.26618695 0.51552308 5.33614063
O 3.33805180 8.83594131 1.40738785
H 3.42589712 9.76232147 1.29618788
H 4.20596886 8.52754688 1.06519771
O 9.64161205 2.24009204 8.21248245
H 10.44588184 2.42511582 8.71880341
H 9.77668190 1.55341518 7.57285404
O 6.24836349 4.95267439 5.06015348
H 6.45660734 5.82238102 4.75356150
H 5.67656326 5.00484085 5.87013102
O 5.23794460 8.03899288 6.54741335
H 5.94747353 8.66842175 6.75902081
H 5.56287575 7.19080162 6.80881691
O 6.24564791 3.63376021 1.14055872
H 6.10870647 2.99982285 1.90240943
H 5.97900295 3.09268880 0.40766400
O 0.66831440 8.69901180 2.04291368
H 1.61453676 8.92202663 1.95631087
H 0.23899707 9.51721478 1.69478559
O 0.97948104 6.47605038 0.23600960
H 0.20574865 6.06915855 0.63948917
H 0.85356140 7.25910330 0.75218296'''

# Parsed once, outside the benchmark loop.
WATER_ATOMS = WATER_CLUSTER.strip().split('\n')
MAX_NWATER = len(WATER_ATOMS) // 3

# Cluster sizes (number of water molecules) to benchmark.
NWATER_LIST = [1]
# Number of timed GPU RHF runs per cluster size.
NREPEAT = 1

for nwater in NWATER_LIST:
    # Slicing past the end would silently build a smaller cluster than the
    # one reported in the "Nwater:" line, so reject out-of-range sizes.
    if not 1 <= nwater <= MAX_NWATER:
        raise ValueError(
            f"nwater must be between 1 and {MAX_NWATER}, got {nwater}")

    mol = gto.M()
    mol.atom = '\n'.join(WATER_ATOMS[:nwater * 3])
    mol.basis = 'ccpvdz'
    mol.build()

    nao = mol.nao
    print("Nwater:", nwater, "Nao:", nao)
    print("Start to build jk")
    # NOTE(review): vhfopt is not used by the timed RHF below. It is kept bound
    # so the pre-timer setup (and the object's lifetime during timing) matches
    # the original benchmark; it may act as a GPU warm-up. Confirm if needed.
    vhfopt = _VHFOpt(mol, 'int2e').build(gpus=1)

    # Synchronize around the timed region so queued GPU work is not missed
    # or attributed to the wrong interval.
    cupy.cuda.Device().synchronize()
    start = time()
    for _ in range(NREPEAT):
        e_gpu = scf.RHF(mol).kernel()
    cupy.cuda.Device().synchronize()
    print("GPU time", time() - start)

    # CPU reference run with PySCF, timed for comparison.
    start = time()
    e_cpu = cscf.RHF(mol).kernel()
    print("CPU time", time() - start)

    # PySCF's SCF.kernel() returns the total energy. This assumes ByteQC's
    # RHF.kernel() does the same (possibly as a 0-d device array), so it is
    # converted with float(). Confirm against byteqc.cuobc.scf.
    e_gpu = float(e_gpu)
    e_cpu = float(e_cpu)
    print("E_gpu", e_gpu, "E_cpu", e_cpu, "|E_gpu - E_cpu|",
          abs(e_gpu - e_cpu))
